"""
Detect variants in reads.
"""
import logging
from collections import defaultdict, Counter
from intervaltree import IntervalTree
from typing import Iterable, Iterator, List, Optional
from dataclasses import dataclass
import random

from .core import Read, ReadSet, NumericSampleIds
from .bam import SampleBamReader, MultiBamReader, BamReader
from .align import edit_distance, edit_distance_affine_gap
from ._variants import _iterate_cigar, _detect_alleles

logger = logging.getLogger(__name__)


class ReadSetError(Exception):
    """
    Exception type for read-set related errors.

    NOTE(review): within this file it is only referenced from commented-out
    code in ReadSetReader._group_paired_reads; confirm whether other modules
    still raise or catch it.
    """
    pass


@dataclass
class AlleleProgress:
    """
    Per-allele counters used while matching a variant against an alignment.

    Instances are created by VariantProgress.add_allele() and zeroed again by
    VariantProgress.reset(). The counters themselves are advanced outside of
    this file (NOTE(review): presumably by _detect_alleles -- confirm).
    """
    # Overall progress; VariantProgress treats the allele as resolved once
    # progress == length and as pending while 0 <= progress < length.
    progress: int = 0
    # Set by add_allele() to matches + insertions + deletions.
    length: int = 0
    # Zeroed by reset(); NOTE(review): presumably an accumulated quality value.
    quality: int = 0
    # Running counter (zeroed by reset()) and its target, set from `matches`.
    matched: int = 0
    match_target: int = 0
    # Running counter (zeroed by reset()) and its target, set from `insertions`.
    inserted: int = 0
    insert_target: int = 0
    # Running counter (zeroed by reset()) and its target, set from `deletions`.
    deleted: int = 0
    delete_target: int = 0


class VariantProgress:
    """
    Progress tracker for one variant: holds one AlleleProgress entry per
    allele, in the order in which the alleles were added.
    """

    def __init__(self, variant_id):
        self.variant_id = variant_id
        self.query_start = 0
        self.alleles = []

    def __iter__(self):
        yield from self.alleles

    def __len__(self):
        return len(self.alleles)

    def add_allele(self, matches, insertions, deletions):
        """Append a fresh tracker whose targets are the given operation counts."""
        self.alleles.append(
            AlleleProgress(
                progress=0,
                length=matches + insertions + deletions,
                quality=0,
                matched=0,
                match_target=matches,
                inserted=0,
                insert_target=insertions,
                deleted=0,
                delete_target=deletions,
            )
        )

    def reset(self, query_start):
        """Store the new query start and zero all running counters."""
        self.query_start = query_start
        for tracker in self.alleles:
            tracker.progress = 0
            tracker.matched = 0
            tracker.inserted = 0
            tracker.deleted = 0
            tracker.quality = 0

    def get_resolved(self):
        """Return the indices of alleles whose progress equals their length."""
        return [
            index
            for index, tracker in enumerate(self.alleles)
            if tracker.progress == tracker.length
        ]

    def get_pending(self):
        """Return the indices of alleles with 0 <= progress < length."""
        return [
            index
            for index, tracker in enumerate(self.alleles)
            if 0 <= tracker.progress < tracker.length
        ]


class ReadSetReader:
    """
    Associate VCF variants with BAM reads.

    A VCF file contains variants, and a BAM file contains reads, but the
    information which read contains which variant is not available. This
    class re-discovers the variants in each read, using the
    knowledge in the VCF of where they should occur.
    """

    def __init__(
        self,
        paths: List[str],
        reference: Optional[str],
        numeric_sample_ids: NumericSampleIds,
        *,
        mapq_threshold: int = 20,
        overhang: int = 10,
        affine: bool = False,
        gap_start: int = 10,
        gap_extend: int = 7,
        default_mismatch: int = 15,
        duplicates: bool = False,
    ):
        """
        Open the given BAM file(s) and store the allele-detection settings.

        paths -- list of BAM paths; a single path is opened with
            SampleBamReader, any other number with MultiBamReader
        reference -- path to reference FASTA (can be None); passed on to the reader
        numeric_sample_ids -- mapping that is indexed with a sample name to
            obtain the numeric sample id stored in each Read
        mapq_threshold -- minimum mapping quality
        overhang -- extend alignment by this many bases to left and right
        affine -- use affine gap costs
        gap_start, gap_extend, default_mismatch -- parameters for affine gap cost alignment
        duplicates -- flag for read alignments marked as duplicate
            (stored as self._duplicates)
        """
        self._mapq_threshold = mapq_threshold
        self._numeric_sample_ids = numeric_sample_ids
        self._use_affine = affine
        self._gap_start = gap_start
        self._gap_extend = gap_extend
        self._default_mismatch = default_mismatch
        self._overhang = overhang
        self._duplicates = duplicates
        self._paths = paths
        self._reader: BamReader
        # One path: single-file reader; otherwise the multi-file reader.
        if len(paths) == 1:
            self._reader = SampleBamReader(paths[0], reference=reference)
        else:
            self._reader = MultiBamReader(paths, reference=reference)

    @property
    def n_paths(self):
        """Number of BAM paths this reader was constructed with."""
        return len(self._paths)

    def read(self, chromosome, variants, sample, reference, regions=None, args=None) -> ReadSet:
        """
        Detect alleles and return a ReadSet object containing reads representing
        the given variants.

        If a reference is provided (reference is not None), alleles are
        detected by re-aligning sections of the query to the REF and ALT
        sequence extended a few bases to the left and right.

        If reference is None, alleles are detected by inspecting the
        existing alignment (via the CIGAR).

        chromosome -- name of chromosome to work on
        variants -- list of vcf.VcfVariant objects
        sample -- name of sample to work on. If None, read group information is
            ignored and all reads in the file are used.
        reference -- reference sequence of the given chromosome (or None)
        regions -- list of start,end tuples (end can be None)
        """
        # Since variants are identified by position, positions must be unique.
        if __debug__ and variants:
            varposc = Counter(variant.position for variant in variants)
            pos, count = varposc.most_common()[0]
            assert count == 1, f"Position {pos} occurs more than once in variant list."

        alignments = self._usable_alignments(chromosome, sample, regions, args)
        reads = self._alignments_to_reads(alignments, variants, sample, reference)
        grouped_reads = self._group_paired_reads(reads)
        readset = self._make_readset_from_grouped_reads(grouped_reads)
        return readset

    @staticmethod
    def _make_readset_from_grouped_reads(groups: Iterable[List[Read]]) -> ReadSet:
        """
        Merge every group of reads into a single Read (via merge_reads) and
        collect the merged reads in a new ReadSet.
        """
        result = ReadSet()
        for members in groups:
            merged = merge_reads(*members)
            result.add(merged)
        return result

    @staticmethod
    def _group_paired_reads(reads: Iterable[Read]) -> Iterator[List[Read]]:
        """
        Group reads into paired-end read pairs. Uses name, source_id and sample_id
        as grouping key.

        TODO
        Grouping by name should be sufficient since the SAM spec states:
        "Reads/segments having identical QNAME are regarded to come from the same template."
        """
        groups = defaultdict(list)
        for read in reads:
            groups[(read.source_id, read.name, read.sample_id)].append(read)
        for group in groups.values():
            # if len(group) > 2:
            #     raise ReadSetError(
            #         f"Read name {group[0].name!r} occurs more than twice in the input file"
            #     )
            yield group

    def get_reference_span(self, cigar):
        ref_span = 0
        cigar_idx = 0
        op_len_str = ""
        while cigar_idx < len(cigar):
            if not cigar[cigar_idx].isdigit():
                if cigar[cigar_idx] in ["M", "D", "N", "=", "X"]:
                    ref_span += int(op_len_str)
                op_len_str = ""
            else:
                op_len_str += cigar[cigar_idx]
            cigar_idx += 1
        return ref_span

    def get_read_start(self, cigar, strand):
        op_len_str = ""
        if strand == "+":
            cigar_idx = 0
            while cigar[cigar_idx].isdigit():
                op_len_str += cigar[cigar_idx]
                cigar_idx += 1
            if cigar[cigar_idx] in ["H", "S"]:
                return int(op_len_str)
        elif cigar[-1] in ["H", "S"]:
            cigar_idx = -2
            while cigar[cigar_idx].isdigit():
                op_len_str += cigar[cigar_idx]
                cigar_idx -= 1
            return int(op_len_str[::-1])
        return 0


    def is_bad_aln(self, read, args):
        if not read.has_tag('NM'): return False
        CIGAR_OPS = ['M', 'I', 'D', 'N', 'S', 'H', 'P', '=', 'X', 'B']
        len_large_indels = 0
        for op, op_len in read.cigartuples:
            if CIGAR_OPS[op] in ["D", "I"] and op_len > 30:
                len_large_indels += op_len
        ref_span = read.reference_end - read.reference_start + 1
        discordance_ratio = (int(read.get_tag('NM')) - len_large_indels) / ref_span
        return discordance_ratio >= args.max_discordance

    def _usable_alignments(self, chromosome, sample, regions=None, args=None):
        """
        Retrieve usable (suficient mapping quality, not secondary etc.)
        alignments from the alignment file
        """
        if regions is None:
            regions = [(0, None)]
        for s, e in regions:
            seen2keep = defaultdict(list)
            seen2singleton = defaultdict(list)
            for alignment in self._reader.fetch(
                reference=chromosome, sample=sample, start=s, end=e
            ):
                # skip alignments
                if alignment.bam_alignment.mapping_quality < self._mapq_threshold: continue
                if alignment.bam_alignment.is_secondary or alignment.bam_alignment.is_unmapped: continue
                if alignment.bam_alignment.is_duplicate or alignment.bam_alignment.is_qcfail: continue
                if args.filter_bad_reads and self.is_bad_aln(alignment.bam_alignment, args): continue

                # ---------- paired short reads
                # split read pairs that map discordantly
                if alignment.bam_alignment.is_paired:
                    if alignment.bam_alignment.is_supplementary: continue
                    if (abs(alignment.bam_alignment.template_length) > args.max_isize
                        or alignment.bam_alignment.template_length == 0
                        or alignment.bam_alignment.mate_is_unmapped
                        or alignment.bam_alignment.next_reference_name != alignment.bam_alignment.reference_name
                        or alignment.bam_alignment.is_reverse == alignment.bam_alignment.mate_is_reverse):
                        alignment.bam_alignment.qname += "_" + "12"[alignment.bam_alignment.is_read2]
                    else: alignment.bam_alignment.qname += "_MP"  # join read pairs
                    yield alignment
                    continue

                # ------------- long reads
                if not alignment.bam_alignment.has_tag('SA'):
                    yield alignment
                    continue

                if alignment.bam_alignment.qname in seen2keep:
                    # already seen a partial alignment, yield if pre-selected
                    read_start = self.get_read_start(alignment.bam_alignment.cigarstring,
                                                     "-" if alignment.bam_alignment.is_reverse else "+")
                    if read_start in seen2keep[alignment.bam_alignment.qname]:
                        #alignment.bam_alignment.qname += "_" + "+-"[alignment.bam_alignment.is_reverse]
                        yield alignment
                        continue
                    elif read_start in seen2singleton[alignment.bam_alignment.qname]:
                        alignment.bam_alignment.qname += "_" + str(random.randint(0, 1000))
                        continue
                    else:
                        assert False

                # extract all partial alignments to this chromosome
                ref_span = alignment.bam_alignment.reference_end - alignment.bam_alignment.reference_start + 1
                curr_strand = "-" if alignment.bam_alignment.is_reverse else "+"
                curr_read_start = self.get_read_start(alignment.bam_alignment.cigarstring, curr_strand)
                partial_alignments = [(alignment.bam_alignment.mapping_quality, ref_span,
                                       alignment.bam_alignment.reference_start,
                                       alignment.bam_alignment.reference_end, curr_read_start, curr_strand)]
                for tag in alignment.bam_alignment.get_tag('SA').rstrip(";").split(';'):
                    entries = tag.split(',')
                    if entries[0] != chromosome: continue
                    ref_start = int(entries[1])
                    strand, cigar, mapq, _ = entries[2:]
                    if int(mapq) < self._mapq_threshold: continue
                    ref_end = ref_start + self.get_reference_span(cigar)
                    ref_span = ref_end - ref_start + 1
                    read_start = self.get_read_start(cigar, strand)
                    partial_alignments.append((int(mapq), ref_span, ref_start, ref_end, read_start, strand))

                selected_intervals = IntervalTree()
                selected_alns = []
                for aln in sorted(partial_alignments, key=lambda element: (element[0], element[1]), reverse=True):
                    if not selected_intervals.overlaps(aln[2], aln[3]):
                        selected_intervals.addi(aln[2], aln[3], aln)
                        selected_alns.append(aln)
                    select = True
                    for c in selected_intervals.overlap(aln[2], aln[3]):
                        overlap = min(aln[3], c.data[3]) - max(aln[2], c.data[2])
                        if overlap > args.read_overlap_th:
                            select = False
                            seen2singleton[alignment.bam_alignment.qname].append(aln[4])
                            break
                    if select:
                        selected_intervals.addi(aln[2], aln[3], aln)
                        selected_alns.append(aln)

                aln_sorted_by_pos = sorted(selected_alns, key=lambda element: element[2])
                seen2keep[alignment.bam_alignment.qname].append(aln_sorted_by_pos[0][4])
                for i in range(1, len(aln_sorted_by_pos)):
                    if abs(aln_sorted_by_pos[i-1][3] - aln_sorted_by_pos[i][2]) < args.supp_distance_th:
                        seen2keep[alignment.bam_alignment.qname].append(aln_sorted_by_pos[i][4])
                    else:
                        seen2singleton[alignment.bam_alignment.qname].append(aln_sorted_by_pos[i][4])
                if curr_read_start not in seen2keep[alignment.bam_alignment.qname]:
                    alignment.bam_alignment.qname += "_" + str(random.randint(0, 1000))
                yield alignment

    def has_reference(self, chromosome):
        """Delegate to the underlying reader's has_reference() for the given chromosome name."""
        return self._reader.has_reference(chromosome)

    def _alignments_to_reads(self, alignments, variants, sample, reference):
        """
        Convert BAM alignments to Read objects.

        If reference is not None, alleles are detected through re-alignment.

        Yield Read objects.
        """
        # FIXME hard-coded zero
        numeric_sample_id = 0 if sample is None else self._numeric_sample_ids[sample]
        if reference is not None:
            # Copy the pyfaidx.FastaRecord into a str for faster access
            reference = reference[:]
            normalized_variants = variants
        else:
            normalized_variants = [variant.normalized() for variant in variants]

        # Create allele progress trackers once, instead of doing it for every read again
        if reference is None:
            # Discard overlapping and duplicate-positioned variants for more efficient iteration
            valid_variant_ids = self.detect_non_overlapping_variants(normalized_variants)
            valid_positions = [normalized_variants[j].position for j in valid_variant_ids]
            var_progress = [
                self.build_var_progress(normalized_variants, j) for j in valid_variant_ids
            ]
            var_progress.sort(key=lambda x: x.variant_id)

        i = 0  # index into variants (reference) or variant progresses (no reference)

        for alignment in alignments:
            try:
                barcode = alignment.bam_alignment.get_tag("BX")
            except KeyError:
                barcode = ""

            read = Read(
                alignment.bam_alignment.qname,
                alignment.bam_alignment.mapq,
                alignment.source_id,
                numeric_sample_id,
                alignment.bam_alignment.reference_start,
                barcode,
                "+-"[alignment.bam_alignment.is_reverse]
            )

            if reference is None:
                # Skip variant progress objects that are to the left of this read
                while (
                    i < len(valid_positions)
                    and valid_positions[i] < alignment.bam_alignment.reference_start
                ):
                    i += 1
                detected = _detect_alleles(
                    normalized_variants, var_progress, i, alignment.bam_alignment
                )
            else:
                # Skip variants that are to the left of this read
                while (
                    i < len(normalized_variants)
                    and normalized_variants[i].position < alignment.bam_alignment.reference_start
                ):
                    i += 1
                detected = self.detect_alleles_by_alignment(
                    variants,
                    i,
                    alignment.bam_alignment,
                    reference,
                    self._overhang,
                    self._use_affine,
                    self._gap_start,
                    self._gap_extend,
                    self._default_mismatch,
                )

            for j, allele, quality in detected:
                read.add_variant(variants[j].position, allele, quality)
            if read:  # At least one variant covered and detected
                yield read

    def detect_non_overlapping_variants(self, variants):
        """
        Checks for deletion variants overlapping other variants and for variants with duplicate
        positions. Returns a set of variant indices, which are conflict with another variant and
        should not be considered for allele detection.

        variants -- list of variants (VcfVariant objects)
        """
        j = 0
        conflicting = set()
        seen_pos = set()
        while j < len(variants):
            v = variants[j]
            if v.position in seen_pos:
                conflicting.add(j)
                j += 1
                continue
            else:
                seen_pos.add(v.position)
            ref = len(v.reference_allele)
            max_del = max(ref - len(alt) for alt in v.get_alt_allele_list())
            if max_del > 0:
                # at least one alt allele shorter than ref allele exists:
                deletion_end = v.position + ref
                if j + 1 < len(variants) and variants[j + 1].position < deletion_end:
                    # at least one follow-up variant overlaps the deletion
                    conflicting.add(j)
                    while j + 1 < len(variants) and variants[j + 1].position < deletion_end:
                        j += 1
                        conflicting.add(j)
            j += 1
        return [j for j in range(len(variants)) if j not in conflicting]

    def build_var_progress(self, variants, j):
        """
        Creates an object for tracking match progress of the j-th variant. Each object contains
        the variant id and lengths for every allele.
        """
        v = VariantProgress(j)
        ref_len = len(variants[j].reference_allele)
        v.add_allele(len(variants[j].reference_allele), 0, 0)
        for i, alt in enumerate(variants[j].get_alt_allele_list()):
            alt_len = len(alt)
            match_target = min(ref_len, alt_len)
            ins_target = max(0, len(alt) - ref_len)
            del_target = max(0, ref_len - len(alt))
            v.add_allele(match_target, ins_target, del_target)
        return v

    @staticmethod
    def split_cigar(cigar, i, consumed):
        """
        Split a CIGAR into two parts. i and consumed describe the split position.
        i is the element of the cigar list that should be split, and consumed says
        at how many operations to split within that element.

        The CIGAR is given as a list of (operation, length) pairs.

        i -- split at this index in cigar list
        consumed -- how many cigar ops at cigar[i] are to the *left* of the
            split position

        Return a tuple (left, right).

        Example:
        Assume the cigar is 3M 1D 6M 2I 4M.
        With i == 2 and consumed == 5, the cigar is split into
        3M 1D 5M and 1M 2I 4M.
        """
        middle_op, middle_length = cigar[i]
        assert consumed <= middle_length
        if consumed > 0:
            left = cigar[:i] + [(middle_op, consumed)]
        else:
            left = cigar[:i]
        if consumed < middle_length:
            right = [(middle_op, middle_length - consumed)] + cigar[i + 1 :]
        else:
            right = cigar[i + 1 :]
        return left, right

    @staticmethod
    def cigar_prefix_length(cigar, reference_bases):
        """
        Given a prefix of length reference_bases relative to the reference, how
        long is the prefix of the read? In other words: If reference_bases on
        the reference are consumed, how many bases on the query does that
        correspond to?

        If the position is within or at the end of an insertion (which do not
        consume bases on the reference), then the number of bases up to the
        beginning of the insertion is reported.

        Return a pair (reference_bases, query_bases) where the value for
        reference_bases may be smaller than the requested one if the CIGAR does
        not cover enough reference bases.

        Reference skips (N operators) are treated as the end of the read. That
        is, no positions beyond a reference skip are reported.
        """
        ref_pos = 0
        query_pos = 0
        for op, length in cigar:
            if op in (0, 7, 8):  # M, X, =
                ref_pos += length
                query_pos += length
                if ref_pos >= reference_bases:
                    return (reference_bases, query_pos + reference_bases - ref_pos)
            elif op == 2:  # D
                ref_pos += length
                if ref_pos >= reference_bases:
                    return (reference_bases, query_pos)
            elif op == 1:  # I
                query_pos += length
            elif op == 4 or op == 5:  # soft or hard clipping
                pass
            elif op == 3:  # N
                # Always stop at reference skips
                return (reference_bases, query_pos)
            else:
                assert False, "unknown CIGAR operator"
        assert ref_pos < reference_bases
        return (ref_pos, query_pos)

    @staticmethod
    def realign(
        variant,
        bam_read,
        cigartuples,
        i,
        consumed,
        query_pos,
        reference,
        overhang,
        use_affine,
        gap_start,
        gap_extend,
        default_mismatch,
    ):
        """
        Realign a read to the two alleles of a single variant.
        i and consumed describe where to split the cigar into a part before the
        variant position and into a part starting at the variant position, see split_cigar().

        variant -- VcfVariant
        bam_read -- the AlignedSegment
        cigartuples -- the AlignedSegment.cigartuples property (accessing it is expensive, so re-use it)
        i, consumed -- see split_cigar method
        query_pos -- index of the query base that is at the variant position
        reference -- the reference as a str-like object (full chromosome)
        overhang -- extend alignment by this many bases to left and right
        use_affine -- if true, use affine gap costs for realignment
        gap_start, gap_extend -- if affine_gap=true, use these parameters for affine gap cost alignment
        default_mismatch -- if affine_gap=true, use this as mismatch cost in case no base qualities are in bam
        """
        # Do not process symbolic alleles like <DEL>, <DUP>, etc.
        if any(alt.startswith("<") for alt in variant.get_alt_allele_list()):
            return None, None

        left_cigar, right_cigar = ReadSetReader.split_cigar(cigartuples, i, consumed)

        left_ref_bases, left_query_bases = ReadSetReader.cigar_prefix_length(
            left_cigar[::-1], overhang
        )
        right_ref_bases, right_query_bases = ReadSetReader.cigar_prefix_length(
            right_cigar, len(variant.reference_allele) + overhang
        )

        assert variant.position - left_ref_bases >= 0
        assert variant.position + right_ref_bases <= len(reference)

        query = bam_read.query_sequence[
            query_pos - left_query_bases : query_pos + right_query_bases
        ]
        pos = variant.position
        left_pad = reference[pos - left_ref_bases : pos]
        right_pad = reference[pos + len(variant.reference_allele) : pos + right_ref_bases]
        padded_alleles = [reference[pos - left_ref_bases : pos + right_ref_bases]]
        for alt in variant.get_alt_allele_list():
            padded_alleles.append(left_pad + alt + right_pad)

        if use_affine:
            assert gap_start is not None
            assert gap_extend is not None
            assert default_mismatch is not None

            # get base qualities if present (to be used as mismatch costs)
            base_qualities = [default_mismatch] * len(query)
            # if bam_read.query_qualities != None:
            #    base_qualities = bam_read.query_qualities[query_pos-left_query_bases:query_pos+right_query_bases]

            # compute edit dist. with affine gap costs using base qual. as mismatch cost
            distances = [
                (i, edit_distance_affine_gap(query, allele, base_qualities, gap_start, gap_extend))
                for i, allele in enumerate(padded_alleles)
            ]
            distances.sort(key=lambda x: x[1])
            base_qual_score = 30 #distances[0][1] - distances[1][1]
        else:
            distances = [
                (i, edit_distance(query, allele)) for i, allele in enumerate(padded_alleles)
            ]
            distances.sort(key=lambda x: x[1])
            base_qual_score = 30

        if distances[0][1] < distances[1][1]:
            return distances[0][0], base_qual_score  # detected REF
        else:
            return None, None  # cannot decide

    @staticmethod
    def detect_alleles_by_alignment(
        variants,
        j,
        bam_read,
        reference,
        overhang=10,
        use_affine=False,
        gap_start=None,
        gap_extend=None,
        default_mismatch=None,
    ):
        """
        Detect which alleles the given bam_read covers. Detect the correct
        alleles of the variants that are covered by the given bam_read.

        Yield tuples (position, allele, quality).

        variants -- list of variants (VcfVariant objects)
        j -- index of the first variant (in the variants list) to check
        """
        # Accessing bam_read.cigartuples is expensive, do it only once
        cigartuples = bam_read.cigartuples

        # For the same reason, the following check is here instad of
        # in the _usable_alignments method
        if not cigartuples:
            return

        for index, i, consumed, query_pos in _iterate_cigar(variants, j, bam_read, cigartuples):
            allele, quality = ReadSetReader.realign(
                variants[index],
                bam_read,
                cigartuples,
                i,
                consumed,
                query_pos,
                reference,
                overhang,
                use_affine,
                gap_start,
                gap_extend,
                default_mismatch,
            )
            num_alts = len(variants[index].get_alt_allele_list())
            if allele in range(num_alts + 1):
                yield (index, allele, quality)  # TODO quality???
            else:
                yield (index, -1, -1)

    def __enter__(self):
        """Enter the context manager; returns the reader itself."""
        return self

    def __exit__(self, *args):
        """Close the reader on leaving the context; exceptions are not suppressed."""
        self.close()

    def close(self):
        """Close the underlying BAM reader."""
        self._reader.close()


def merge_two_reads(read1: Read, read2: Read) -> Read:
    """
    Merge two reads *that belong to the same haplotype* (such as the two
    ends of a paired-end read) into a single Read. Overlaps are allowed.

    Both reads must be sorted. If read2 is empty, read1 itself is returned.
    Otherwise a new Read is built from read1's metadata, carrying the first
    mapping quality of each input, with the variants of both reads in
    position order. For a position covered by both reads:

    * if one read has allele -1 there, the other read's entry is used;
    * if both alleles agree, one entry with the summed quality is added;
    * if they disagree, the position is left out of the result.
    """
    assert read1.is_sorted()
    assert read2.is_sorted()

    if not read2:
        return read1

    merged = Read(
        read1.name,
        read1.mapqs[0],
        read1.source_id,
        read1.sample_id,
        read1.reference_start,
        read1.BX_tag,
        read1.strand,
    )
    merged.add_mapq(read2.mapqs[0])

    def take(entry):
        merged.add_variant(entry.position, entry.allele, entry.quality)

    count1 = len(read1)
    count2 = len(read2)
    i1 = 0
    i2 = 0
    # Standard merge of two sorted sequences
    while i1 < count1 and i2 < count2:
        first = read1[i1]
        second = read2[i2]
        if second.position < first.position:
            take(second)
            i2 += 1
            continue
        if second.position > first.position:
            take(first)
            i1 += 1
            continue

        # Variant on self-overlapping read pair
        assert first.position == second.position
        if first.allele == -1:
            take(second)
        elif second.allele == -1:
            take(first)
        elif first.allele == second.allele:
            merged.add_variant(first.position, first.allele, first.quality + second.quality)
        # Conflicting alleles: neither entry is taken over
        i1 += 1
        i2 += 1

    # At most one of the two reads still has entries left
    for k in range(i1, count1):
        take(read1[k])
    for k in range(i2, count2):
        take(read2[k])
    return merged


def merge_reads(*reads: Read) -> Read:
    """
    Merge multiple reads that belong to the same haplotype into a single Read.

    Raises ValueError if called without any read. A single read is returned
    as it is.

    This 'naive' version folds the reads from left to right with
    merge_two_reads.

    # TODO
    # The actual challenge is dealing with conflicts in variants covered by
    # more than one read. A solution would be to not merge if there are any
    # (or too many) conflicts and let the main algorithm deal with it.
    """
    if not reads:
        raise ValueError("no reads to merge")
    merged = reads[0]
    assert merged.is_sorted()
    for partner in reads[1:]:
        merged = merge_two_reads(merged, partner)
    return merged

